# experiment1 = baseline + part
# 5-way 5-shot, 15-query, 600 tasks: mean accuracy 0.8326

from lib.algorithms import ALGORITHMS
from .experiment1_net import Experiment1
from .experiment1_loss import Experiment1Loss
from .experiment1_metric import Experiment1Acc
from lib.sampler import NormalSampler
from lib.sampler import EpisodeSampler


@ALGORITHMS.register('experiment1')
def expirement1(cfg):
    """Build the components of the ``experiment1`` algorithm.

    NOTE(review): the function name is misspelled ("expirement1"); it is
    kept unchanged in case other modules import it by this name. The
    registry key passed to the decorator is spelled correctly.

    Args:
        cfg: configuration object, forwarded unchanged to the network,
            loss and metric constructors.

    Returns:
        A dict whose ``'net'``, ``'loss'`` and ``'metric'`` entries are
        freshly constructed instances, and whose ``'train_sampler'``,
        ``'validate_sampler'`` and ``'test_sampler'`` entries are sampler
        classes (not instances).
    """
    # Instantiate in a fixed order: network, then loss, then metric.
    network = Experiment1(cfg)
    criterion = Experiment1Loss(cfg)
    accuracy = Experiment1Acc(cfg)

    components = dict(
        net=network,
        loss=criterion,
        metric=accuracy,
    )
    # Training and validation both use NormalSampler; testing uses
    # EpisodeSampler. The classes themselves are returned.
    components.update(
        train_sampler=NormalSampler,
        validate_sampler=NormalSampler,
        test_sampler=EpisodeSampler,
    )
    return components
